from google.colab import drive
drive.mount('/content/gdrive')
import numpy as np
import cv2
import imageio
import glob
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from scipy import ndimage
# Read every normal-light ("high") training PNG into one array.
# Filenames are sorted so the order is deterministic; the training code later
# pairs high[i] with low[i] by position.
path = r'/content/gdrive/My Drive/LOLdataset/our485/high'
all_files = sorted(glob.glob(path + "/*.png"))
high = np.array([imageio.imread(fileName) for fileName in all_files])
# Read every low-light training PNG into one array, in sorted filename order
# (the same ordering rule used for the normal-light images).
path = r'/content/gdrive/My Drive/LOLdataset/our485/low'
all_files = sorted(glob.glob(path + "/*.png"))
low = np.array([imageio.imread(fileName) for fileName in all_files])
# Side-by-side preview of one training pair (index 479): ground truth on the
# left, its low-light counterpart on the right.
plt.figure(figsize=(30, 30))
panels = (
    ("Ground Truth", high[479]),
    ("Low Light Image", low[479]),
)
for slot, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 2, slot)
    plt.title(caption, fontsize=20)
    plt.imshow(picture)
# We employ a plain CNN of seven convolutional layers with symmetrical skip
# connections (the code below merges them with element-wise addition). Each of
# the first six layers consists of 32 convolutional kernels of size 3x3 with
# stride 1, followed by the ReLU activation function. The last convolutional
# layer is followed by the sigmoid activation function, which produces a
# parameter map whose values lie in the range [0, 1].
def _relu_conv(tensor):
    """Apply one 32-filter, 3x3, stride-1, same-padded ReLU convolution."""
    conv = layers.Conv2D(
        32,
        (3, 3),
        strides=(1, 1),
        padding="same",
        activation='relu',
    )
    return conv(tensor)


# Fully convolutional network: the spatial dimensions are left as None so the
# model accepts 3-channel images of any height and width.
inputs = keras.Input(shape=(None, None, 3), name='img')

# Four stacked convolutions ...
out1 = _relu_conv(inputs)
out2 = _relu_conv(out1)
out3 = _relu_conv(out2)
out4 = _relu_conv(out3)

# ... followed by additive skip connections that mirror the first half
# (out3 + out4, out2 + out5, out1 + out6).
in5 = layers.add([out3, out4])
out5 = _relu_conv(in5)
in6 = layers.add([out2, out5])
out6 = _relu_conv(in6)
in7 = layers.add([out1, out6])

# Final layer: a single-channel sigmoid map, so every output value is in [0, 1].
outputs = layers.Conv2D(
    1,
    (3, 3),
    strides=(1, 1),
    padding="same",
    activation='sigmoid',
)(in7)

model = keras.Model(inputs, outputs)
model.compile(loss='mean_squared_error', optimizer="adam")
model.summary()
keras.utils.plot_model(model, to_file='mini_resnet.png', show_shapes=True)
def GenerateInputs(X, y):
    """Yield one (input, target) training batch per image pair.

    Each batch holds a single image: a leading batch axis of size 1 is added
    to both arrays. The input image is passed through with its original
    values; the target image is divided by 255.

    The batch axis is added by indexing rather than by reshaping to a fixed
    (1, 400, 600, 3), so images of any height and width are accepted.

    Args:
        X: sequence of input images (arrays), e.g. the low-light set.
        y: sequence of target images (arrays), paired with ``X`` by position.

    Yields:
        Tuples ``(X_input, y_input)`` where each element has the shape of the
        source image with an extra leading axis of length 1.

    Note:
        This is a single pass over the data: the generator is exhausted after
        the shorter of ``X`` and ``y`` runs out, so the caller must not request
        more batches in total than there are image pairs.
    """
    for source, target in zip(X, y):
        X_input = source[np.newaxis, ...]
        y_input = (target / 255)[np.newaxis, ...]
        yield (X_input, y_input)
# Train on single-image batches. GenerateInputs makes one pass over the data,
# so 20 epochs x 24 steps draws 480 batches in total from it.
train_batches = GenerateInputs(low, high)
model.fit(train_batches, steps_per_epoch=24, epochs=20, verbose=1)

# Persist the trained network to Drive.
model.save(r'/content/gdrive/My Drive/ImageProcessing/ourModel.h5')
# Read every file in the low-light evaluation folder, in sorted filename order.
path = r'/content/gdrive/My Drive/LOLdataset/eval15/low'
all_files = sorted(glob.glob(path + "/*"))
x = [imageio.imread(fileName) for fileName in all_files]
X = np.array(x)
# Inputs of Enhance:
#   img:   the low-light image (I, E_old)
#   index: number of enhancement iterations
#   flag:  default = 1
# For more information, see README.md.
def Enhance(img, index, flag=1):
    """Iteratively brighten ``img`` using the map predicted by ``model``.

    Each iteration asks the module-level ``model`` for a map ``a`` computed
    from the current image and then applies ``img + (a * img) * (1 - img)``.

    Args:
        img: image array of shape (height, width, 3).
        index: number of iterations to run. Zero or a negative value runs no
            iteration and returns ``img`` unchanged (not rescaled).
        flag: when equal to 1, ``img`` is divided by 255 during the first
            iteration and all later iterations work on the rescaled image.
            Any other value means ``img`` is used as given.

    Returns:
        The image after ``index`` iterations, or ``img`` itself when no
        iteration is run.
    """
    # A loop instead of recursion: a negative ``index`` can no longer recurse
    # without bound, and the update formula is written only once.
    while index > 0:
        # The prediction is always made on the image as it is at the start of
        # the iteration, i.e. before any division by 255.
        # NOTE(review): the first iteration therefore feeds unscaled values to
        # the network while later iterations feed the rescaled image -
        # confirm this matches how the model was trained.
        curve = model.predict(img[np.newaxis, ...])[0]
        if flag == 1:
            img = img / 255
            flag = 0
        img = img + (curve * img) * (1 - img)
        index -= 1
    return img
# Show three evaluation samples next to their enhanced versions.
# Each entry is (index into X, number of enhancement iterations).
for sample_idx, iterations in ((1, 8), (2, 8), (6, 12)):
    Image = X[sample_idx]
    plt.figure(figsize=(30, 30))

    # Left panel: the untouched low-light input.
    plt.subplot(1, 2, 1)
    plt.title("Low Light Image", fontsize=20)
    plt.imshow(Image)

    # Right panel: the result of the iterative enhancement.
    plt.subplot(1, 2, 2)
    plt.title("Enhanced Image", fontsize=20)
    plt.imshow(Enhance(Image, iterations, 1))
# Read every file in the extra test folder, in sorted filename order.
path = r'/content/gdrive/My Drive/ImageProcessing/DataSet/test'
all_files = sorted(glob.glob(path + "/*"))
x = [imageio.imread(fileName) for fileName in all_files]
Z = np.array(x)
# Show two images from the extra test set next to their enhanced versions,
# each enhanced with 12 iterations.
for sample_idx in (0, 3):
    Image = Z[sample_idx]
    plt.figure(figsize=(30, 30))

    # Left panel: the untouched low-light input.
    plt.subplot(1, 2, 1)
    plt.title("Low Light Image", fontsize=20)
    plt.imshow(Image)

    # Right panel: the result of the iterative enhancement.
    plt.subplot(1, 2, 2)
    plt.title("Enhanced Image", fontsize=20)
    plt.imshow(Enhance(Image, 12, 1))